/*
 * Copyright 2002-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.messaging.simp.stomp;

import java.io.ByteArrayOutputStream;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;

import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpLogging;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderInitializer;
import org.springframework.messaging.support.NativeMessageHeaderAccessor;
import org.springframework.util.InvalidMimeTypeException;
import org.springframework.util.MultiValueMap;

/**
 * Decodes one or more STOMP frames contained in a {@link ByteBuffer}.
 *
 * <p>An attempt is made to read all complete STOMP frames from the buffer, which
 * could be zero, one, or more. If there is any left-over content at the end, i.e.
 * an incomplete STOMP frame, the buffer is reset to point to the beginning of the
 * partial content. The caller is then responsible for dealing with that
 * incomplete content by buffering until there is more input available.
 *
 * @author Andy Wilkinson
 * @author Rossen Stoyanchev
 * @since 4.0
 */
public class StompDecoder {

    static final byte[] HEARTBEAT_PAYLOAD = new byte[]{'\n'};

    private static final Log logger = SimpLogging.forLogName(StompDecoder.class);

    @Nullable
    private MessageHeaderInitializer headerInitializer;


    /**
     * Configure a {@link MessageHeaderInitializer} to apply to the headers of
     * {@link Message Messages} from decoded STOMP frames.
     */
    public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitializer) {
        this.headerInitializer = headerInitializer;
    }

    /**
     * Return the configured {@code MessageHeaderInitializer}, if any.
     */
    @Nullable
    public MessageHeaderInitializer getHeaderInitializer() {
        return this.headerInitializer;
    }


    /**
     * Decodes one or more STOMP frames from the given {@code ByteBuffer} into a
     * list of {@link Message Messages}. If the input buffer contains partial STOMP
     * frame content, or additional content with a partial STOMP frame, the buffer
     * is reset to point to the beginning of the partial content and only the
     * complete frames decoded before it (possibly none) are returned.
     *
     * @param byteBuffer the buffer to decode the STOMP frame from
     * @return the decoded messages, or an empty list if none
     * @throws StompConversionException raised in case of decoding issues
     */
    public List<Message<byte[]>> decode(ByteBuffer byteBuffer) {
        return decode(byteBuffer, null);
    }

    /**
     * Decodes one or more STOMP frames from the given {@code buffer} and returns
     * a list of {@link Message Messages}.
     * <p>If the given ByteBuffer contains only partial STOMP frame content and no
     * complete STOMP frames, an empty list is returned, and the buffer is reset
     * to where it was.
     * <p>If the buffer contains one or more STOMP frames, those are returned and
     * the buffer reset to point to the beginning of the unused partial content.
     * <p>The output partialMessageHeaders map is used to store successfully parsed
     * headers in case of partial content. The caller can then check if a
     * "content-length" header was read, which helps to determine how much more
     * content is needed before the next attempt to decode.
     *
     * @param byteBuffer the buffer to decode the STOMP frame from
     * @param partialMessageHeaders an empty output map that receives the native
     * headers parsed so far when the buffer ends with a partial STOMP frame;
     * may be {@code null} if the caller does not need them
     * @return the decoded messages, or an empty list if none
     * @throws StompConversionException raised in case of decoding issues
     */
    public List<Message<byte[]>> decode(ByteBuffer byteBuffer,
                                        @Nullable MultiValueMap<String, String> partialMessageHeaders) {

        List<Message<byte[]>> messages = new ArrayList<>();
        while (byteBuffer.hasRemaining()) {
            Message<byte[]> message = decodeMessage(byteBuffer, partialMessageHeaders);
            if (message == null) {
                // Incomplete frame: the buffer has already been reset to its start
                break;
            }
            messages.add(message);
        }
        return messages;
    }

    /**
     * Decode a single STOMP frame from the given {@code buffer} into a {@link Message}.
     * <p>Leading EOLs are skipped first. If nothing but EOLs remained, a heartbeat
     * message is returned. If the remaining content is an incomplete frame, the
     * buffer is reset to the position right after the skipped EOLs, any native
     * headers parsed so far are copied into {@code partialMessageHeaders} (when
     * non-null), and {@code null} is returned.
     * @return the decoded message, or {@code null} for an incomplete frame
     */
    @Nullable
    private Message<byte[]> decodeMessage(ByteBuffer byteBuffer,
                                          @Nullable MultiValueMap<String, String> partialMessageHeaders) {

        skipLeadingEol(byteBuffer);

        // Explicit mark/reset access via Buffer base type for compatibility
        // with covariant return type on JDK 9's ByteBuffer...
        Buffer buffer = byteBuffer;
        buffer.mark();

        String command = readCommand(byteBuffer);
        if (command.isEmpty()) {
            // Only EOLs were left in the buffer: report them as a heartbeat
            StompHeaderAccessor heartbeatAccessor = StompHeaderAccessor.createForHeartbeat();
            initHeaders(heartbeatAccessor);
            heartbeatAccessor.setLeaveMutable(true);
            Message<byte[]> heartbeat =
                    MessageBuilder.createMessage(HEARTBEAT_PAYLOAD, heartbeatAccessor.getMessageHeaders());
            if (logger.isTraceEnabled()) {
                logger.trace("Decoded " + heartbeatAccessor.getDetailedLogMessage(null));
            }
            return heartbeat;
        }

        StompHeaderAccessor headerAccessor = null;
        byte[] payload = null;
        if (byteBuffer.hasRemaining()) {
            StompCommand stompCommand = StompCommand.valueOf(command);
            headerAccessor = StompHeaderAccessor.create(stompCommand);
            initHeaders(headerAccessor);
            readHeaders(byteBuffer, headerAccessor);
            payload = readPayload(byteBuffer, headerAccessor);
        }

        if (payload == null) {
            logger.trace("Incomplete frame, resetting input buffer...");
            if (partialMessageHeaders != null && headerAccessor != null) {
                String name = NativeMessageHeaderAccessor.NATIVE_HEADERS;
                @SuppressWarnings("unchecked")
                MultiValueMap<String, String> map = (MultiValueMap<String, String>) headerAccessor.getHeader(name);
                if (map != null) {
                    partialMessageHeaders.putAll(map);
                }
            }
            buffer.reset();
            return null;
        }

        // A non-null payload implies the frame's headers were read above
        if (payload.length > 0) {
            StompCommand stompCommand = headerAccessor.getCommand();
            if (stompCommand != null && !stompCommand.isBodyAllowed()) {
                // Report this frame's own headers; partialMessageHeaders is only an
                // output map for incomplete frames and is null or empty here
                throw new StompConversionException(stompCommand +
                        " shouldn't have a payload: length=" + payload.length +
                        ", headers=" + headerAccessor.toNativeHeaderMap());
            }
        }
        headerAccessor.updateSimpMessageHeadersFromStompHeaders();
        headerAccessor.setLeaveMutable(true);
        Message<byte[]> decodedMessage = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
        if (logger.isTraceEnabled()) {
            logger.trace("Decoded " + headerAccessor.getDetailedLogMessage(payload));
        }
        return decodedMessage;
    }

    /**
     * Apply the configured {@link MessageHeaderInitializer}, if any.
     */
    private void initHeaders(StompHeaderAccessor headerAccessor) {
        MessageHeaderInitializer initializer = getHeaderInitializer();
        if (initializer != null) {
            initializer.initHeaders(headerAccessor);
        }
    }

    /**
     * Skip one or more EOL characters at the start of the given ByteBuffer.
     * Those are STOMP heartbeat frames.
     */
    protected void skipLeadingEol(ByteBuffer byteBuffer) {
        while (tryConsumeEndOfLine(byteBuffer)) {
            // keep consuming EOLs
        }
    }

    /**
     * Read the command line, i.e. all bytes up to the first EOL (which is
     * consumed) or the end of the buffer, decoded as UTF-8.
     */
    private String readCommand(ByteBuffer byteBuffer) {
        ByteArrayOutputStream command = new ByteArrayOutputStream(256);
        while (byteBuffer.hasRemaining() && !tryConsumeEndOfLine(byteBuffer)) {
            command.write(byteBuffer.get());
        }
        return new String(command.toByteArray(), StandardCharsets.UTF_8);
    }

    /**
     * Read header lines into the given accessor until the blank line that ends
     * the header section, or until the buffer runs out in the middle of a line.
     * <p>A header without a name, or one with an invalid MIME type, is only
     * rejected when more content follows it. When nothing follows, it is ignored;
     * the frame is then treated as incomplete, because no payload can be read.
     */
    private void readHeaders(ByteBuffer byteBuffer, StompHeaderAccessor headerAccessor) {
        while (true) {
            ByteArrayOutputStream headerStream = new ByteArrayOutputStream(256);
            boolean headerComplete = false;
            while (byteBuffer.hasRemaining()) {
                if (tryConsumeEndOfLine(byteBuffer)) {
                    headerComplete = true;
                    break;
                }
                headerStream.write(byteBuffer.get());
            }
            if (headerStream.size() == 0 || !headerComplete) {
                // Either the blank line ending the headers, or a truncated header line
                return;
            }
            String header = new String(headerStream.toByteArray(), StandardCharsets.UTF_8);
            int colonIndex = header.indexOf(':');
            if (colonIndex <= 0) {
                if (byteBuffer.hasRemaining()) {
                    throw new StompConversionException("Illegal header: '" + header +
                            "'. A header must be of the form <name>:[<value>].");
                }
            } else {
                String headerName = unescape(header.substring(0, colonIndex));
                String headerValue = unescape(header.substring(colonIndex + 1));
                try {
                    headerAccessor.addNativeHeader(headerName, headerValue);
                } catch (InvalidMimeTypeException ex) {
                    if (byteBuffer.hasRemaining()) {
                        throw ex;
                    }
                }
            }
        }
    }

    /**
     * See STOMP Spec 1.2:
     * <a href="http://stomp.github.io/stomp-specification-1.2.html#Value_Encoding">"Value Encoding"</a>.
     * <p>Translates {@code \r}, {@code \n}, {@code \c} and {@code \\}; any other
     * escape sequence, or a trailing lone backslash, is rejected.
     */
    private String unescape(String inString) {
        int index = inString.indexOf('\\');
        if (index < 0) {
            return inString;
        }
        StringBuilder sb = new StringBuilder(inString.length());
        int pos = 0;  // start of the not-yet-copied remainder of the input
        while (index >= 0) {
            sb.append(inString, pos, index);
            if (index + 1 >= inString.length()) {
                throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString);
            }
            char c = inString.charAt(index + 1);
            switch (c) {
                case 'r':
                    sb.append('\r');
                    break;
                case 'n':
                    sb.append('\n');
                    break;
                case 'c':
                    sb.append(':');
                    break;
                case '\\':
                    sb.append('\\');
                    break;
                default:
                    // Undefined escape sequences are a fatal protocol error per STOMP 1.2
                    throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString);
            }
            pos = index + 2;
            index = inString.indexOf('\\', pos);
        }
        sb.append(inString, pos, inString.length());
        return sb.toString();
    }

    /**
     * Read the frame body that follows the headers.
     * <p>With a valid, non-negative "content-length" header, exactly that many
     * bytes are read, and they must be followed by a NULL octet. Otherwise the
     * body extends up to the first NULL octet, which is consumed.
     * @return the body, or {@code null} if the buffer ends before the frame does
     * @throws StompConversionException if a content-length delimited body is not
     * terminated with a NULL octet
     */
    @Nullable
    private byte[] readPayload(ByteBuffer byteBuffer, StompHeaderAccessor headerAccessor) {
        Integer contentLength;
        try {
            contentLength = headerAccessor.getContentLength();
        } catch (NumberFormatException ex) {
            if (logger.isDebugEnabled()) {
                logger.debug("Ignoring invalid content-length: '" +
                        headerAccessor.getFirstNativeHeader(StompHeaderAccessor.STOMP_CONTENT_LENGTH_HEADER) + "'");
            }
            contentLength = null;
        }

        if (contentLength != null && contentLength >= 0) {
            // Need the whole body plus the terminating NULL octet
            if (byteBuffer.remaining() <= contentLength) {
                return null;
            }
            byte[] payload = new byte[contentLength];
            byteBuffer.get(payload);
            if (byteBuffer.get() != 0) {
                throw new StompConversionException("Frame must be terminated with a null octet");
            }
            return payload;
        }

        ByteArrayOutputStream payload = new ByteArrayOutputStream(256);
        while (byteBuffer.hasRemaining()) {
            byte b = byteBuffer.get();
            if (b == 0) {
                return payload.toByteArray();
            }
            payload.write(b);
        }
        return null;
    }

    /**
     * Try to read an EOL ({@code "\n"} or {@code "\r\n"}), advancing the buffer
     * position only if one was consumed.
     * <p>A {@code '\r'} that is the last byte in the buffer is left unconsumed and
     * {@code false} is returned. Its {@code '\n'} may simply not have arrived yet,
     * and the callers in this class then treat the frame as incomplete rather
     * than rejecting it.
     * @return whether an EOL was consumed
     * @throws StompConversionException if a {@code '\r'} is followed by anything
     * other than {@code '\n'}
     */
    private boolean tryConsumeEndOfLine(ByteBuffer byteBuffer) {
        if (!byteBuffer.hasRemaining()) {
            return false;
        }
        int position = byteBuffer.position();
        // Absolute get: peek at the next byte without moving the position
        byte b = byteBuffer.get(position);
        if (b == '\n') {
            // Explicit cast for compatibility with covariant return type on JDK 9's ByteBuffer
            ((Buffer) byteBuffer).position(position + 1);
            return true;
        }
        if (b == '\r') {
            if (byteBuffer.remaining() < 2) {
                // CRLF split across input chunks: wait for more content
                return false;
            }
            if (byteBuffer.get(position + 1) != '\n') {
                throw new StompConversionException("'\\r' must be followed by '\\n'");
            }
            ((Buffer) byteBuffer).position(position + 2);
            return true;
        }
        return false;
    }

}
